import json
import math

import numpy as np
import sys
from wenet_gxl import AiConstant

# Module-level logger supplied by the project's AiConstant helper; used below
# to report unsupported (binary) Kaldi CMVN files before exiting.
logger = AiConstant.AI_LOGGER()


def _load_json_cmvn(json_cmvn_file: str):
    """
    Load a JSON CMVN stats file and turn it into per-feature normalisation
    terms.

    The JSON object is read for three keys:
        'mean_stat': per-feature sum of all frame values
        'var_stat':  per-feature sum of squares of all frame values
        'frame_num': number of accumulated frames

    Args:
        json_cmvn_file: path of the JSON stats file.

    Returns:
        numpy array built from two rows: row 0 is the per-feature mean and
        row 1 is the per-feature inverse standard deviation (1 / std).
    """
    with open(json_cmvn_file) as stats_fp:
        stats = json.load(stats_fp)

    sums = stats['mean_stat']
    square_sums = stats['var_stat']
    num_frames = stats['frame_num']

    # Both lists are overwritten in place: sums -> means,
    # square_sums -> inverse standard deviations.
    for idx, total in enumerate(sums):
        mu = total / num_frames
        sums[idx] = mu
        # E[x^2] - E[x]^2, floored so the square root / reciprocal stay finite.
        spread = max(square_sums[idx] / num_frames - mu * mu, 1.0e-20)
        square_sums[idx] = 1.0 / math.sqrt(spread)

    return np.array([sums, square_sums])


def _load_kaldi_cmvn(kaldi_cmvn_file: str):
    """
    Load a Kaldi text-format global CMVN stats file and compute CMVN.

    After whitespace splitting, the file is expected to hold the tokens
        [ sum_1 .. sum_D count sqsum_1 .. sqsum_D 0 ]
    i.e. D per-dimension sums, the frame count, D per-dimension sums of
    squares and a trailing 0, wrapped in brackets.

    Args:
        kaldi_cmvn_file:  kaldi text style global cmvn file, which
           is generated by:
           compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array built from two rows: row 0 holds the per-dimension
        means and row 1 the per-dimension inverse standard deviations
        (1 / std), not the raw variances.

    Raises:
        AssertionError: if the opening '[', trailing '0' or closing ']'
            token is missing.
        SystemExit: (exit code 1) if the file is a Kaldi binary file.
    """
    # Kaldi binary files start with the two bytes '\0B'. Sniff them in binary
    # mode: decoding a binary matrix as text can raise UnicodeDecodeError
    # before the header comparison is ever reached.
    with open(kaldi_cmvn_file, 'rb') as fid:
        is_binary = fid.read(2) == b'\0B'
    if is_binary:
        logger.error('kaldi cmvn binary file is not supported, please '
                     'recompute it by: compute-cmvn-stats --binary=false '
                     ' scp:feats.scp global_cmvn')
        sys.exit(1)

    with open(kaldi_cmvn_file, 'r') as fid:
        arr = fid.read().split()

    assert arr[0] == '[', "kaldi cmvn file must start with '['"
    assert arr[-2] == '0', "kaldi cmvn file must end with '0 ]'"
    assert arr[-1] == ']', "kaldi cmvn file must end with ']'"

    # Tokens: '[' + D sums + count + D square sums + '0' + ']'.
    feat_dim = int((len(arr) - 2 - 2) / 2)
    means = [float(tok) for tok in arr[1:feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    variance = [float(tok) for tok in arr[feat_dim + 2:2 * feat_dim + 2]]

    for i in range(len(means)):
        means[i] /= count
        # E[x^2] - E[x]^2, floored to keep the reciprocal square root finite.
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        # Store the inverse standard deviation used for normalisation.
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file, is_json):
    """
    Load CMVN statistics from a JSON or Kaldi text file.

    Args:
        cmvn_file: path of the stats file.
        is_json: truthy -> parse as JSON via _load_json_cmvn;
            falsy -> parse as Kaldi text via _load_kaldi_cmvn.

    Returns:
        A 2-tuple of the first and second rows of the loaded array, i.e.
        (per-feature means, per-feature inverse standard deviations).
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    return stats[0], stats[1]
